package hex.glrm;

import hex.DataInfo;
import hex.ModelMetrics;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.genmodel.algos.glrm.GlrmRegularizer;
import hex.glrm.GLRMModel.GLRMParameters;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.Log;

import java.util.Random;
import java.util.concurrent.ExecutionException;

public class GLRMCategoricalTest extends TestUtil {
  /** Numeric tolerance (1e-6). No test in this class currently reads it. */
  public final double TOLERANCE = 1e-6;
  /**
   * Runs once before the tests of this class and calls {@code stall_till_cloudsize(1)}
   * (presumably waits until a cloud of at least one node is available; see TestUtil to confirm).
   */
  @BeforeClass
  public static void setup() {
    stall_till_cloudsize(1);
  }

  /**
   * Formats every entry of {@code cols} in its natural order by delegating to
   * {@link #colFormat(String[], String, int[])} with the identity index mapping.
   *
   * @param cols   column names
   * @param format {@link String#format} pattern applied to each name individually
   * @return the concatenated formatted names followed by a newline
   */
  private static String colFormat(String[] cols, String format) {
    final int[] identity = new int[cols.length];
    int pos = 0;
    while (pos < identity.length) {
      identity[pos] = pos;
      pos++;
    }
    return colFormat(cols, format, identity);
  }

  /**
   * Formats column names in the order given by {@code idx}: output position {@code i} shows
   * {@code cols[idx[i]]}. Exactly {@code cols.length} entries are emitted.
   *
   * @param cols   column names
   * @param format {@link String#format} pattern applied to each name individually
   * @param idx    for each output position, the index into {@code cols} to print
   * @return the concatenated formatted names followed by a newline
   */
  private static String colFormat(String[] cols, String format, int[] idx) {
    final StringBuilder line = new StringBuilder();
    final int count = cols.length;
    for (int pos = 0; pos < count; pos++) {
      final String name = cols[idx[pos]];
      line.append(String.format(format, name));
    }
    return line.append("\n").toString();
  }

  /**
   * Formats an expanded header in natural column order by delegating to
   * {@link #colExpFormat(String[], String[][], String, int[])} with the identity index mapping.
   *
   * @param cols    column names
   * @param domains per-column level names; {@code null} marks a column printed by its own name
   * @param format  {@link String#format} pattern applied to each printed label
   * @return the concatenated formatted labels followed by a newline
   */
  private static String colExpFormat(String[] cols, String[][] domains, String format) {
    final int[] identity = new int[cols.length];
    for (int pos = identity.length - 1; pos >= 0; pos--) {
      identity[pos] = pos;
    }
    return colExpFormat(cols, domains, format, identity);
  }

  /**
   * Formats an expanded header: for each of the {@code domains.length} output positions the column
   * {@code idx[i]} is looked up; a column whose domain is {@code null} contributes its name from
   * {@code cols}, any other column contributes one label per entry of its domain.
   *
   * @param cols    column names
   * @param domains per-column level names; {@code null} marks a column printed by its own name
   * @param format  {@link String#format} pattern applied to each printed label
   * @param idx     for each output position, the index of the column to print
   * @return the concatenated formatted labels followed by a newline
   */
  private static String colExpFormat(String[] cols, String[][] domains, String format, int[] idx) {
    final StringBuilder line = new StringBuilder();
    for (int pos = 0; pos < domains.length; pos++) {
      final int col = idx[pos];
      final String[] levels = domains[col];
      if (levels == null) {
        // No domain: a single label, the column name itself
        line.append(String.format(format, cols[col]));
        continue;
      }
      // One label per level of the column's domain
      for (String level : levels) {
        line.append(String.format(format, level));
      }
    }
    return line.append("\n").toString();
  }

  /**
   * Trains a GLRM with k = 4 on the iris file (absolute loss, SVD initialization, no transform,
   * SVD recovery on, at most 1000 iterations), scores the training frame and logs the objective
   * and the GLRM metrics. The test makes no assertions; it passes when nothing throws.
   */
  @Test
  public void testCategoricalIris() throws InterruptedException, ExecutionException {
    Frame irisFrame = null;
    GLRMModel glrm = null;

    try {
      irisFrame = parseTestFile(Key.make("iris.hex"), "smalldata/iris/iris_wheader.csv");

      final GLRMParameters p = new GLRMParameters();
      p._train = irisFrame._key;
      // Model shape and optimisation budget
      p._k = 4;
      p._max_iterations = 1000;
      // Loss, initialization and preprocessing
      p._loss = GlrmLoss.Absolute;
      p._init = GlrmInitialization.SVD;
      p._transform = DataInfo.TransformType.NONE;
      p._recover_svd = true;

      glrm = new GLRM(p).trainModel().get();
      Log.info("Iteration " + glrm._output._iterations + ": Objective value = " + glrm._output._objective);

      // Scoring is done for its side effect of producing metrics; the scored frame is discarded
      glrm.score(irisFrame).delete();
      final ModelMetricsGLRM metrics = (ModelMetricsGLRM) ModelMetrics.getFromDKV(glrm, irisFrame);
      Log.info("Numeric Sum of Squared Error = " + metrics._numerr + "\tCategorical Misclassification Error = " + metrics._caterr);
    } finally {
      if (irisFrame != null) {
        irisFrame.delete();
      }
      if (glrm != null) {
        glrm.delete();
      }
    }
  }

  /**
   * Trains a GLRM with k = 8 on the prostate file after converting columns 1, 3, 4 and 5 to
   * categoricals and dropping the ID column. Uses quadratic regularization on X and Y
   * (gamma 0.1 each), PlusPlus initialization, STANDARDIZE transform and at most 200 iterations;
   * then scores the training frame and logs the objective and the GLRM metrics.
   * The test makes no assertions; it passes when nothing throws.
   */
  @Test public void testCategoricalProstate() throws InterruptedException, ExecutionException {
    GLRMModel model = null;
    Frame train = null;
    final int[] cats = new int[]{1,3,4,5};    // Categoricals: CAPSULE, RACE, DPROS, DCAPS

    // Enter the scope before the try block so the Scope.exit() in finally is only ever
    // reached after a completed Scope.enter() (same layout as the other tests in this class).
    Scope.enter();
    try {
      train = parseTestFile(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      // Swap each listed column for its categorical version; whatever replace() returns
      // (presumably the displaced vec) is tracked so Scope.exit() can clean it up.
      for (int col : cats)
        Scope.track(train.replace(col, train.vec(col).toCategoricalVec()));
      // Detach the ID column from the frame and remove the detached vec
      train.remove("ID").remove();
      // Re-publish the modified frame under its key
      DKV.put(train._key, train);

      GLRMParameters parms = new GLRMParameters();
      parms._train = train._key;
      parms._k = 8;
      parms._gamma_x = parms._gamma_y = 0.1;
      parms._regularization_x = GlrmRegularizer.Quadratic;
      parms._regularization_y = GlrmRegularizer.Quadratic;
      parms._init = GlrmInitialization.PlusPlus;
      parms._transform = DataInfo.TransformType.STANDARDIZE;
      parms._recover_svd = false;
      parms._max_iterations = 200;

      model = new GLRM(parms).trainModel().get();
      Log.info("Iteration " + model._output._iterations + ": Objective value = " + model._output._objective);
      // Scoring is done for its side effect of producing metrics; the scored frame is discarded
      model.score(train).delete();
      ModelMetricsGLRM mm = (ModelMetricsGLRM)ModelMetrics.getFromDKV(model, train);
      Log.info("Numeric Sum of Squared Error = " + mm._numerr + "\tCategorical Misclassification Error = " + mm._caterr);
    } finally {
      if (train != null) train.delete();
      if (model != null) model.delete();
      Scope.exit();
    }
  }

  /**
   * Trains one GLRM per combination of numeric loss (Quadratic, Absolute, Huber, Poisson) and
   * multi-level loss (Categorical, Ordinal) on the prostate file, with columns 1, 3, 4 and 5
   * converted to categoricals and the ID column dropped. For every combination the model seed,
   * both regularizers and both gamma values are drawn from a {@link Random} seeded with 0xDECAF,
   * so runs are reproducible. Each model is scored on the training frame and its metrics are
   * logged. The test makes no assertions; it passes when nothing throws.
   */
  @Test public void testLosses() throws InterruptedException, ExecutionException {
    long seed = 0xDECAF;
    Random rng = new Random(seed);
    Frame train = null;
    final int[] cats = new int[]{1,3,4,5};    // Categoricals: CAPSULE, RACE, DPROS, DCAPS
    final GlrmRegularizer[] regs = new GlrmRegularizer[] {
            GlrmRegularizer.Quadratic,
            GlrmRegularizer.L1,
            GlrmRegularizer.NonNegative,
            GlrmRegularizer.OneSparse,
            GlrmRegularizer.UnitOneSparse,
            GlrmRegularizer.Simplex
    };

    Scope.enter();
    try {
      train = parseTestFile(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      // Swap each listed column for its categorical version; whatever replace() returns
      // (presumably the displaced vec) is tracked so Scope.exit() can clean it up.
      for (int col : cats)
        Scope.track(train.replace(col, train.vec(col).toCategoricalVec()));
      // Detach the ID column from the frame and remove the detached vec
      train.remove("ID").remove();
      // Re-publish the modified frame under its key
      DKV.put(train._key, train);

      for(GlrmLoss loss : new GlrmLoss[] {
              GlrmLoss.Quadratic,
              GlrmLoss.Absolute,
              GlrmLoss.Huber,
              GlrmLoss.Poisson
      }) {
        for(GlrmLoss multiloss : new GlrmLoss[] {
                GlrmLoss.Categorical,
                GlrmLoss.Ordinal
        }) {
          GLRMModel model = null;
          // Enter the per-model scope before the try so the Scope.exit() in finally is only
          // ever reached after a completed Scope.enter().
          Scope.enter();
          try {
            // NOTE: the order of the rng draws below (seed, reg x, reg y, gamma x, gamma y)
            // determines the configuration of every model; do not reorder.
            long myseed = rng.nextLong();
            Log.info("GLRM using seed = " + myseed);

            GLRMParameters parms = new GLRMParameters();
            parms._train = train._key;
            parms._transform = DataInfo.TransformType.NONE;
            parms._k = 5;
            parms._loss = loss;
            parms._multi_loss = multiloss;
            parms._init = GlrmInitialization.SVD;
            parms._regularization_x = regs[rng.nextInt(regs.length)];
            parms._regularization_y = regs[rng.nextInt(regs.length)];
            // Random.nextDouble() already yields a value in [0.0, 1.0), so no abs() is needed
            parms._gamma_x = rng.nextDouble();
            parms._gamma_y = rng.nextDouble();
            parms._recover_svd = false;
            parms._seed = myseed;
            parms._verbose = false;
            parms._max_iterations = 500;

            model = new GLRM(parms).trainModel().get();
            Log.info("Iteration " + model._output._iterations + ": Objective value = " + model._output._objective);
            // Scoring is done for its side effect of producing metrics; the scored frame is discarded
            model.score(train).delete();
            ModelMetricsGLRM mm = (ModelMetricsGLRM)ModelMetrics.getFromDKV(model, train);
            Log.info("Numeric Sum of Squared Error = " + mm._numerr + "\tCategorical Misclassification Error = " + mm._caterr);
          } finally {
            if (model != null) model.delete();
            Scope.exit();
          }
        }
      }
    } finally {
      if(train != null) train.delete();
      Scope.exit();
    }
  }

  /**
   * Trains a GLRM with k = 12 on the prostate file (columns 1, 3, 4 and 5 converted to
   * categoricals, ID dropped) using Quadratic / Categorical default losses plus per-column
   * overrides for three columns, then hands the parameters and model to
   * {@code GLRMTest.checkLossbyCol}. Afterwards the training frame is scored and the GLRM metrics
   * are logged.
   */
  @Test
  public void testSetColumnLossCats() throws InterruptedException, ExecutionException {
    final int[] cats = new int[]{1,3,4,5};    // Categoricals: CAPSULE, RACE, DPROS, DCAPS
    Frame prostate = null;
    GLRMModel glrm = null;

    Scope.enter();
    try {
      prostate = parseTestFile(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      // Convert the listed columns to categoricals, tracking what replace() hands back
      for (int col : cats) {
        Scope.track(prostate.replace(col, prostate.vec(col).toCategoricalVec()));
      }
      // Drop the ID column, then re-publish the modified frame under its key
      prostate.remove("ID").remove();
      DKV.put(prostate._key, prostate);

      final GLRMParameters p = new GLRMParameters();
      p._train = prostate._key;
      // Model shape and optimisation settings
      p._k = 12;
      p._init = GlrmInitialization.PlusPlus;
      p._min_step_size = 1e-5;
      p._max_iterations = 2000;
      p._recover_svd = false;
      // Default losses, followed by the per-column overrides (indices refer to the frame after ID removal)
      p._loss = GlrmLoss.Quadratic;
      p._multi_loss = GlrmLoss.Categorical;
      p._loss_by_col = new GlrmLoss[] { GlrmLoss.Ordinal, GlrmLoss.Poisson, GlrmLoss.Absolute};
      p._loss_by_col_idx = new int[] { 3 /* DPROS */, 1 /* AGE */, 6 /* VOL */ };

      glrm = new GLRM(p).trainModel().get();
      Log.info("Iteration " + glrm._output._iterations + ": Objective value = " + glrm._output._objective);
      GLRMTest.checkLossbyCol(p, glrm);

      // Scoring is done for its side effect of producing metrics; the scored frame is discarded
      glrm.score(prostate).delete();
      final ModelMetricsGLRM metrics = (ModelMetricsGLRM) ModelMetrics.getFromDKV(glrm, prostate);
      Log.info("Numeric Sum of Squared Error = " + metrics._numerr + "\tCategorical Misclassification Error = " + metrics._caterr);
    } finally {
      if (prostate != null) {
        prostate.delete();
      }
      if (glrm != null) {
        glrm.delete();
      }
      Scope.exit();
    }
  }

  /**
   * Checks {@code GLRM.expandCats} on five hand-written iris rows: the rows are permuted with the
   * {@code _permutation} of a {@link DataInfo} built from the parsed iris frame, expanded, and
   * compared against a hard-coded expected matrix in which the class column appears as three
   * indicator columns ahead of the four numeric columns. Intermediate matrices are logged.
   */
  @Test
  public void testExpandCatsIris() throws InterruptedException, ExecutionException {
    final String[] colNames = new String[] {"sepal_len", "sepal_wid", "petal_len", "petal_wid", "class"};
    final String[][] colDomains = new String[][] { null, null, null, null, new String[] {"setosa", "versicolor", "virginica"} };

    // Input rows: four numeric measurements followed by the class level index
    final double[][] rows = ard(ard(6.3, 2.5, 4.9, 1.5, 1),
            ard(5.7, 2.8, 4.5, 1.3, 1),
            ard(5.6, 2.8, 4.9, 2.0, 2),
            ard(5.0, 3.4, 1.6, 0.4, 0),
            ard(6.0, 2.2, 5.0, 1.5, 2));
    // Expected expansion: three class indicator columns, then the numeric values
    final double[][] expected = ard(ard(0, 1, 0, 6.3, 2.5, 4.9, 1.5),
            ard(0, 1, 0, 5.7, 2.8, 4.5, 1.3),
            ard(0, 0, 1, 5.6, 2.8, 4.9, 2.0),
            ard(1, 0, 0, 5.0, 3.4, 1.6, 0.4),
            ard(0, 0, 1, 6.0, 2.2, 5.0, 1.5));

    Frame irisFrame = null;
    try {
      irisFrame = parseTestFile(Key.make("iris.hex"), "smalldata/iris/iris_wheader.csv");
      final DataInfo info = new DataInfo(irisFrame, null, 0, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, false, false, /* weights */ false, /* offset */ false, /* fold */ false);

      Log.info("Original matrix:\n" + colFormat(colNames, "%8.7s") + ArrayUtils.pprint(rows));

      final double[][] permuted = ArrayUtils.permuteCols(rows, info._permutation);
      Log.info("Permuted matrix:\n" + colFormat(colNames, "%8.7s", info._permutation) + ArrayUtils.pprint(permuted));

      final double[][] expanded = GLRM.expandCats(permuted, info);
      Log.info("Expanded matrix:\n" + colExpFormat(colNames, colDomains, "%8.7s", info._permutation) + ArrayUtils.pprint(expanded));

      Assert.assertArrayEquals(expected, expanded);
    } finally {
      if (irisFrame != null) {
        irisFrame.delete();
      }
    }
  }

  /**
   * Checks {@code GLRM.expandCats} on four hand-written prostate rows: the rows are permuted with
   * the {@code _permutation} of a {@link DataInfo} built from the parsed prostate frame (columns
   * 1, 3, 4 and 5 converted to categoricals, ID dropped), expanded, and compared against a
   * hard-coded expected matrix of eleven indicator columns followed by the four numeric columns.
   * Intermediate matrices are logged.
   */
  @Test public void testExpandCatsProstate() throws InterruptedException, ExecutionException {
    double[][] prostate = ard(ard(0, 71, 1, 0, 0,  4.8, 14.0, 7),
            ard(1, 70, 1, 1, 0,  8.4, 21.8, 5),
            ard(0, 73, 1, 3, 0, 10.0, 27.4, 6),
            ard(1, 68, 1, 0, 0,  6.7, 16.7, 6));
    double[][] pros_expandR = ard(ard(1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 71,  4.8, 14.0, 7),
            ard(0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 70,  8.4, 21.8, 5),
            ard(0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 73, 10.0, 27.4, 6),
            ard(1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 68,  6.7, 16.7, 6));
    // Names and domains below are only used to label the logged matrices
    String[] pros_cols = new String[]{"Capsule", "Age", "Race", "Dpros", "Dcaps", "PSA", "Vol", "Gleason"};
    String[][] pros_domains = new String[][]{new String[]{"No", "Yes"}, null, new String[]{"Other", "White", "Black"},
            new String[]{"None", "UniLeft", "UniRight", "Bilobar"}, new String[]{"No", "Yes"}, null, null, null};
    final int[] cats = new int[]{1,3,4,5};    // Categoricals: CAPSULE, RACE, DPROS, DCAPS

    Frame fr = null;
    // Enter the scope before the try block so the Scope.exit() in finally is only ever
    // reached after a completed Scope.enter() (same layout as the other tests in this class).
    Scope.enter();
    try {
      fr = parseTestFile(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      // Swap each listed column for its categorical version; whatever replace() returns
      // (presumably the displaced vec) is tracked so Scope.exit() can clean it up.
      for (int col : cats)
        Scope.track(fr.replace(col, fr.vec(col).toCategoricalVec()));
      // Detach the ID column from the frame and remove the detached vec
      fr.remove("ID").remove();
      // Re-publish the modified frame under its key
      DKV.put(fr._key, fr);
      DataInfo dinfo = new DataInfo(fr, null, 0, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, false, false, /* weights */ false, /* offset */ false, /* fold */ false);

      Log.info("Original matrix:\n" + colFormat(pros_cols, "%8.7s") + ArrayUtils.pprint(prostate));
      double[][] pros_perm = ArrayUtils.permuteCols(prostate, dinfo._permutation);
      Log.info("Permuted matrix:\n" + colFormat(pros_cols, "%8.7s", dinfo._permutation) + ArrayUtils.pprint(pros_perm));

      double[][] pros_exp = GLRM.expandCats(pros_perm, dinfo);
      Log.info("Expanded matrix:\n" + colExpFormat(pros_cols, pros_domains, "%8.7s", dinfo._permutation) + ArrayUtils.pprint(pros_exp));
      Assert.assertArrayEquals(pros_expandR, pros_exp);
    } finally {
      if (fr != null) fr.delete();
      Scope.exit();
    }
  }
}
